package org.canaan.mybatis;

import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Properties;

/**
 * 分页插件
 */

@Intercepts({@org.apache.ibatis.plugin.Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class,Integer.class})})
public class MysqlPagingPlugin implements Interceptor {
    private final Logger logger = LoggerFactory.getLogger(MysqlPagingPlugin.class);
    private String dialect   = ""; //数据库
    private String pageSqlId = ""; //mapper.xml中需要拦截的ID(正则匹配)

    public Object intercept(Invocation invocation) throws Throwable {
        RoutingStatementHandler statementHandler = (RoutingStatementHandler) invocation.getTarget();
        StatementHandler delegate = (StatementHandler) ReflectHelper.getFieldValue(statementHandler, "delegate");
        MappedStatement mappedStatement = (MappedStatement) ReflectHelper.getFieldValue(delegate, "mappedStatement");
        if (mappedStatement.getId().matches(pageSqlId)) {
            BoundSql boundSql = delegate.getBoundSql();
            String sql = boundSql.getSql();

            //取得记录总数
            Page page = (Page) boundSql.getParameterObject();
            setTotalRecord(page, mappedStatement, (Connection) invocation.getArgs()[0]);

            //取得Mysql的分页语句
            String pagingSql = getMysqlPageSql(page, sql);
            ReflectHelper.setFieldValue(boundSql, "sql", pagingSql);
        }
        return invocation.proceed();
    }


    /**
     * Plugin.wrap生成拦截代理对象
     */
    public Object plugin(Object o) {
        if (o instanceof StatementHandler) {
            return Plugin.wrap(o, this);
        } else {
            return o;
        }
    }

    public void setProperties(Properties properties) {

    }

    /**
     * 给当前的参数对象page设置总记录数
     *
     * @param page            Mapper映射语句对应的参数对象
     * @param mappedStatement Mapper映射语句
     * @param connection      当前的数据库连接
     */
    private void setTotalRecord(Page page,
                                MappedStatement mappedStatement, Connection connection) {

        //获取对应的BoundSql，这个BoundSql其实跟我们利用StatementHandler获取到的BoundSql是同一个对象。
        //delegate里面的boundSql也是通过mappedStatement.getBoundSql(paramObj)方法获取到的。
        BoundSql boundSql = mappedStatement.getBoundSql(page);
        //获取到我们自己写在Mapper映射语句中对应的Sql语句
        String sql = boundSql.getSql();
        //通过查询Sql语句获取到对应的计算总记录数的sql语句
        String countSql = this.getCountSql(sql);
        //通过BoundSql获取对应的参数映射
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        //利用Configuration、查询记录数的Sql语句countSql、参数映射关系parameterMappings和参数对象page建立查询记录数对应的BoundSql对象。
        BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, page);
        //通过mappedStatement、参数对象page和BoundSql对象countBoundSql建立一个用于设定参数的ParameterHandler对象
        ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, page, countBoundSql);
        //通过connection建立一个countSql对应的PreparedStatement对象。
        PreparedStatement pstmt = null;
        ResultSet rs = null;
        try {
            pstmt = connection.prepareStatement(countSql);
            //通过parameterHandler给PreparedStatement对象设置参数
            parameterHandler.setParameters(pstmt);
            //之后就是执行获取总记录数的Sql语句和获取结果了。
            rs = pstmt.executeQuery();
            if (rs.next()) {
                long count = rs.getInt(1);
                if (page == null) {
                    page = new Page();
                }
                page.setTotalRecNum(count);//设置总记录数
            }
        } catch (SQLException e) {
            logger.error("统计语句出错：{}",countSql,e);
        } finally {
            try {
                if (rs != null)
                    rs.close();
                if (pstmt != null)
                    pstmt.close();
            } catch (SQLException e) {
                logger.error("统计语句出错：资源关闭失败",e);
            }
        }
    }

    /**
     * 根据原Sql语句获取对应的查询总记录数的Sql语句
     *
     * @param sql
     * @return
     */
    private String getCountSql(String sql) {
        sql = sql.trim();
        if (StringUtils.endsWith(sql, ";")) {
            sql =   StringUtils.removeEnd(sql, ";");
        }
        return "select count(1) from (" + sql + ") a";
    }

    /**
     * 根据page对象获取对应的分页查询Sql语句，这里只做了两种数据库类型，Mysql和Oracle
     * 其它的数据库都 没有进行分页
     *
     * @param page 分页对象
     * @param sql  原sql语句
     * @return
     */
    private String getPageSql(Page page, String sql) {
        if ("mysql".equalsIgnoreCase(dialect)) {
            return getMysqlPageSql(page, sql);
        }
        return sql;
    }

    /**
     * 获取Mysql数据库的分页查询语句
     *
     * @param page      分页对象
     * @param sql  包含原sql语句的StringBuffer对象
     * @return Mysql数据库分页语句
     */
    private String getMysqlPageSql(Page page, String sql) {
        sql = sql.trim();
        if (StringUtils.endsWith(sql, ";")) {
           sql =  StringUtils.removeEnd(sql, ";");
        }

        return sql+" limit "+page.getOffSet()+" , "+page.getPageSize();
    }

    /**
     * 设置注册拦截器时设定的属性
     */

    public String getDialect() {
        return dialect;
    }

    public void setDialect(String dialect) {
        this.dialect = dialect;
    }

    public String getPageSqlId() {
        return pageSqlId;
    }

    public void setPageSqlId(String pageSqlId) {
        this.pageSqlId = pageSqlId;
    }


}